// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "sharded_to_interleaved_partial_op.hpp"

#include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.hpp"
#include "ttnn/operations/data_movement/common/common.hpp"

using namespace tt::tt_metal;

namespace ttnn::operations::data_movement {

// Checks the preconditions for converting one slice of a sharded input into the caller-provided
// output tensor.
//
// input_tensors.at(0) is the sharded input and input_tensors.at(1) is the output tensor that was
// passed in by the caller. Every violated precondition is reported through TT_FATAL.
void ShardedToInterleavedPartialDeviceOperation::validate(const std::vector<Tensor>& input_tensors) const {
    const auto& input_tensor = input_tensors.at(0);
    const auto& output_tensor = input_tensors.at(1);

    // Slice bookkeeping: the requested slice must exist and the output height must split evenly.
    TT_FATAL(
        slice_index >= 0 && slice_index < num_slices,
        "Slice index and num_slices don't match! Index = {} num_slices = {}",
        slice_index,
        num_slices);
    TT_FATAL(input_tensor.layout() == Layout::TILE, "Currently, only tile layout is supported for partial S->I");
    TT_FATAL(
        (output_tensor.physical_volume() / output_tensor.padded_shape()[-1]) % num_slices == 0,
        "Total height of a tensor must be divisible by num_slices!");

    TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to shard need to be on device!");
    TT_FATAL(input_tensor.buffer() != nullptr, "Operands to shard need to be allocated in buffers on device!");

    TT_FATAL(input_tensor.memory_config().is_sharded(), "Input tensor must be sharded");

    // Read the shard spec only after the sharded check above, so that a non-sharded input is
    // reported with the message above rather than failing inside std::optional::value().
    const auto& maybe_shard_spec = input_tensor.shard_spec();
    TT_FATAL(maybe_shard_spec.has_value(), "Input tensor must have a shard spec");
    auto shard_spec = maybe_shard_spec.value();

    if (input_tensor.memory_config().memory_layout() != TensorMemoryLayout::HEIGHT_SHARDED) {
        // When the tensor does not divide evenly into shards (in width or in height), the shard
        // grid is required to be a single core range.
        if (input_tensor.padded_shape()[-1] % shard_spec.shape[1] != 0 ||
            ((input_tensor.physical_volume() / input_tensor.padded_shape()[-1]) % shard_spec.shape[0]) != 0) {
            TT_FATAL(
                shard_spec.grid.ranges().size() == 1,
                "Input tensor shard spec must have exactly 1 grid range but got {}",
                shard_spec.grid.ranges().size());
        }
    }
    // A dtype change between input and output requires TILE layout. This is currently subsumed by
    // the unconditional TILE check above; it is kept so the requirement survives if that check is
    // ever relaxed.
    if (input_tensor.dtype() != this->output_dtype) {
        TT_FATAL(
            input_tensor.layout() == Layout::TILE,
            "Input tensor layout must be TILE but got {}",
            input_tensor.layout());
    }
    // Divisibility of num_cores and shard size with tensor shape is done in tensor creation, so no need to assert here
}

// Returns an empty spec list. No output is created by this operation because the output tensor
// has already been passed in by the caller.
std::vector<ttnn::TensorSpec> ShardedToInterleavedPartialDeviceOperation::compute_output_specs(
    const std::vector<Tensor>& input_tensors) const {
    return std::vector<ttnn::TensorSpec>{};
}
// Builds the performance model for this operation.
//
// The cycle estimate comes from common_tm_bw_model, evaluated on the input tensor
// (input_tensors.at(0)) and the caller-provided output tensor (input_tensors.at(1)).
// optional_input_tensors is not used.
tt::tt_metal::operation::OpPerformanceModelGeneral<std::vector<Tensor>>
ShardedToInterleavedPartialDeviceOperation::create_op_performance_model(
    const std::vector<Tensor>& input_tensors,
    const std::vector<std::optional<const Tensor>>& optional_input_tensors,
    std::vector<Tensor>& output_tensors) const {
    using PerfModel = tt::tt_metal::operation::OpPerformanceModelGeneral<std::vector<Tensor>>;

    const auto& src = input_tensors.at(0);
    const auto& dst = input_tensors.at(1);
    int estimated_cycles = common_tm_bw_model(src, dst);

    return PerfModel(input_tensors, output_tensors, estimated_cycles);
}

// Creates the device program for this operation by delegating to
// detail::sharded_to_interleaved_multi_core.
//
// input_tensors.at(0) is the input and input_tensors.at(1) is the caller-provided output tensor;
// the output_tensors parameter is not used here. The slice configuration (num_slices,
// slice_index) is forwarded from this operation's members.
operation::ProgramWithCallbacks ShardedToInterleavedPartialDeviceOperation::create_program(
    const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const {
    const auto& input_tensor = input_tensors.at(0);
    // Bounds-checked access, matching how validate() and create_op_performance_model() fetch the
    // output tensor; a missing second entry throws instead of being undefined behaviour.
    const auto& output_tensor = input_tensors.at(1);
    // Will move with sharded ops
    // NOTE(review): the meaning of the literal `false` is defined by the program factory's
    // signature, which is not visible in this file.
    return detail::sharded_to_interleaved_multi_core(
        input_tensor, output_tensor, false, this->num_slices, this->slice_index);
}

}  // namespace ttnn::operations::data_movement
